
import os
# Run from this script's own directory so relative paths (e.g. the results
# file opened below) resolve next to the script, not the caller's cwd.
dir_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(dir_path)

# NOTE(review): presumably stops the Intel Fortran runtime (bundled with some
# MKL-based NumPy/SciPy builds) from installing its own Ctrl-C handler; it is
# set before numpy/scipy are imported below. Confirm on the target platform.
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'

# +
import torch
from torch.utils.data import Dataset
from torch.distributions import MultivariateNormal
import numpy as np
from all_estimators import *
np.random.seed(42)
import random
random.seed(42)
# Dataset generation draws from torch's RNG as well (MultivariateNormal
# sampling, rand_like noise, random matrices), which the numpy/random seeds
# above do not cover; seed it too so runs are reproducible.
torch.manual_seed(42)
import argparse
from scipy import stats
import time

# Wall-clock start; the elapsed time is printed at the end of the run.
xx = time.time()


# -

def list_of_strings(arg):
    """Parse a comma-separated command-line value into a list of names.

    Whitespace around each item is stripped and empty items are dropped, so
    "cube,sigmoid", "cube, sigmoid" and "cube,sigmoid," all give
    ['cube', 'sigmoid']. Without stripping, ' sigmoid' would not match any
    transform name.
    """
    return [item.strip() for item in arg.split(',') if item.strip()]

# Command-line interface. The positional values are kept as strings: they are
# converted with int() further down, and `args` is echoed verbatim in the report.
parser = argparse.ArgumentParser(
    description="Evaluate mutual-information estimators on synthetic datasets with known MI.")
parser.add_argument("N", help="number of samples per generated dataset")
parser.add_argument("dim", help="dimensionality of x and y before any transform")
parser.add_argument("data_ID",
                    help="dataset generator: 0 = MultivariateNormalDataset, 1 = GaussianAdditionDataset")
# Without a default, an omitted flag yields None, which the dataset classes
# cannot iterate; ['none'] means "leave the data unchanged".
parser.add_argument("--tx", type=list_of_strings, default=['none'],
                    help="comma-separated transforms applied to x, in order")
parser.add_argument("--ty", type=list_of_strings, default=['none'],
                    help="comma-separated transforms applied to y, in order")

args = parser.parse_args()


import sys

# Everything printed from here on is appended to the results file; the
# original stream is kept in orig_stdout so it can be restored afterwards.
orig_stdout = sys.stdout
f = open('Results_ID_0_MINE_only.txt', "a", encoding="utf-8")
sys.stdout = f



# print(args.tx[0])
# print(args.ty)
# print(args)



def transform_points(data, transform, params=None):
    """Apply one named transform to a batch of points and return the result.

    Parameters
    ----------
    data : torch.Tensor
        2-D tensor; the concatenating and 'randmat' transforms work along
        dimension 1 (columns).
    transform : str
        One of 'none', 'sigmoid', 'cube', 'concat_self',
        'concat_self_noisy', 'randmat'.
    params : sequence, optional
        Parameters for 'concat_self' ([n_copies]) and 'concat_self_noisy'
        ([n_blocks, noise_scale]). When omitted they are read from the
        module-level ``params_dict``, as before.

    Returns
    -------
    torch.Tensor
        The transformed data. 'randmat' returns a float64 tensor.

    Raises
    ------
    ValueError
        If ``transform`` is not a recognised name. Previously the function
        silently returned None, which failed later with an unrelated error.
    """
    if transform == 'none':
        return data
    if transform == 'sigmoid':
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
        return torch.sigmoid(data)
    if transform == 'cube':
        return data**3.0
    if transform in ('concat_self', 'concat_self_noisy'):
        if params is None:
            params = params_dict[transform]
        n_blocks = params[0]
        if transform == 'concat_self':
            # Append n_blocks exact copies of the original columns.
            extra = [data] * n_blocks
        else:
            # NOTE(review): despite the name, this appends pure uniform noise
            # in [0, params[1]), not noisy copies of `data`. Confirm that is
            # intended. The rand_like calls happen in the same order as before.
            extra = [torch.rand_like(data) * params[1] for _ in range(n_blocks)]
        return torch.cat([data, *extra], dim=1)
    if transform == 'randmat':
        n_cols = data.shape[1]
        # Draw a random square matrix (uniform entries times a random overall
        # scale) until inversion succeeds, presumably so the map is
        # invertible and therefore information-preserving.
        while True:
            rand_mat = torch.rand(1) * torch.rand(n_cols, n_cols)
            try:
                torch.linalg.inv(rand_mat)
            except RuntimeError:
                # torch.linalg.LinAlgError (singular input) subclasses
                # RuntimeError. Unlike the old bare except, this no longer
                # swallows KeyboardInterrupt.
                continue
            break
        return torch.matmul(data.double(), rand_mat.double())
    raise ValueError(f"unknown transform: {transform!r}")


class MultivariateNormalDataset(Dataset):
    """Paired samples (x, y) drawn jointly from a correlated Gaussian.

    The joint vector [x, y] of length ``2 * dim`` is sampled from a zero-mean
    multivariate normal. Its covariance is the identity, except that
    ``cov[i, dim + i] = cov[dim + i, i] = rho`` for every ``i < dim``. After
    sampling, the transforms named in ``transforms_y`` and then
    ``transforms_x`` are applied in order with ``transform_points``, and both
    arrays are stored as NumPy arrays.

    Parameters
    ----------
    N : int
        Number of samples.
    dim : int
        Dimensionality of x and of y before any transform.
    rho : float
        Correlation between x[:, i] and y[:, i].
    transforms_x, transforms_y : sequence of str or None
        Transform names. None (e.g. an omitted ``--tx``/``--ty`` flag) means
        ``['none']``, i.e. leave the data unchanged.
    """

    def __init__(self, N, dim, rho, transforms_x=None, transforms_y=None):
        self.N = N
        self.rho = rho
        self.dim = dim
        # None defaults avoid a shared mutable default list and also accept
        # the None that argparse produces for an omitted flag.
        self.x_transforms = ['none'] if transforms_x is None else transforms_x
        self.y_transforms = ['none'] if transforms_y is None else transforms_y

        self.dist = self.build_dist

        joint = self.dist.sample((N, ))
        self.y = joint[:, dim:]
        self.x = joint[:, :dim]

        self.transform_both()

    def __getitem__(self, ix):
        """Return the (x, y) pair at index ``ix``.

        x and y are stored separately and may have different widths after
        transforms. The old version sliced both halves out of ``self.x``,
        which returned an empty or wrong y.
        """
        return self.x[ix], self.y[ix]

    def transform_both(self):
        """Apply the configured transforms (y first, then x) and convert to NumPy."""
        # y is transformed before x, as before, so RNG-consuming transforms
        # ('randmat', 'concat_self_noisy') draw in the same order.
        for name in self.y_transforms:
            self.y = transform_points(self.y, name)
        for name in self.x_transforms:
            self.x = transform_points(self.x, name)

        self.x = self.x.numpy()
        self.y = self.y.numpy()

    def __len__(self):
        return self.N

    @property
    def build_dist(self):
        """Zero-mean MultivariateNormal over the joint [x, y] vector."""
        mu = torch.zeros(2 * self.dim)
        dist = MultivariateNormal(mu, self.cov_matrix)
        return dist

    @property
    def cov_matrix(self):
        """Joint covariance: identity plus ``rho`` on the x_i / y_i cross terms."""
        cov = torch.zeros((2 * self.dim, 2 * self.dim))
        cov[torch.arange(self.dim), torch.arange(self.dim, 2 * self.dim)] = self.rho
        cov[torch.arange(self.dim, 2 * self.dim), torch.arange(self.dim)] = self.rho
        cov[torch.arange(2 * self.dim), torch.arange(2 * self.dim)] = 1.0
        return cov

    @property
    def true_mi(self):
        """Closed-form MI between x and y (in nats): -0.5 * log det(cov).

        The log-determinant is computed in float64 with ``slogdet``. The old
        float32 ``det`` underflowed to 0 for large ``dim`` (e.g. dim=200,
        rho=0.8), which made the reference MI ``inf``.
        """
        _, logdet = np.linalg.slogdet(self.cov_matrix.double().numpy())
        return -0.5 * logdet

    


class GaussianAdditionDataset(Dataset):
    """Paired samples with y = x + independent Gaussian noise.

    ``x ~ N(0, I_dim)`` and ``y = x + n`` with ``n ~ N(0, I_dim / SNR)``, both
    drawn from NumPy's global RNG. The transforms named in ``transforms_y``
    and then ``transforms_x`` are applied in order with ``transform_points``,
    and both arrays are stored as NumPy arrays.

    Parameters
    ----------
    N : int
        Number of samples.
    dim : int
        Dimensionality of x and of y before any transform.
    SNR : float
        Signal-to-noise ratio; the noise standard deviation is sqrt(1 / SNR).
    transforms_x, transforms_y : sequence of str or None
        Transform names. None (e.g. an omitted ``--tx``/``--ty`` flag) means
        ``['none']``, i.e. leave the data unchanged.
    """

    def __init__(self, N, dim, SNR, transforms_x=None, transforms_y=None):
        self.N = N
        self.dim = dim
        self.SNR = SNR
        # None defaults avoid a shared mutable default list and also accept
        # the None that argparse produces for an omitted flag.
        self.x_transforms = ['none'] if transforms_x is None else transforms_x
        self.y_transforms = ['none'] if transforms_y is None else transforms_y

        # Signal first, then noise, as before, so the NumPy RNG stream is unchanged.
        x = np.random.normal(0., 1, [self.N, self.dim])
        y = x + np.random.normal(0., np.sqrt(1/self.SNR), [self.N, self.dim])
        self.x = torch.from_numpy(x)
        self.y = torch.from_numpy(y)

        self.transform_both()

    def __getitem__(self, ix):
        """Return the (x, y) pair at index ``ix``. Map-style Datasets require this method."""
        return self.x[ix], self.y[ix]

    def transform_both(self):
        """Apply the configured transforms (y first, then x) and convert to NumPy."""
        for name in self.y_transforms:
            self.y = transform_points(self.y, name)
        for name in self.x_transforms:
            self.x = transform_points(self.x, name)

        self.x = self.x.numpy()
        self.y = self.y.numpy()

    def __len__(self):
        return self.N

    @property
    def true_mi(self):
        """Closed-form MI of the additive Gaussian channel, in nats: dim/2 * log(1 + SNR)."""
        return self.dim*0.5*np.log(1 + self.SNR)


# Per-transform parameters read by transform_points:
#   concat_self       -> [number of extra copies appended]
#   concat_self_noisy -> [number of extra blocks appended, noise scale]
# The remaining transforms take no parameters.
params_dict = dict(
    concat_self=[20],
    randmat=[],
    cube=[],
    concat_self_noisy=[20, 0.2],
    sigmoid=[],
)



# Hyperparameters shared by the MINE-family (neural) estimators.
total_epochs, batch_size, hidden_layer = 40, 400, 20


def _mine_config(*extra):
    """Return a fresh MI_Estimator parameter list for the MINE variants."""
    return [total_epochs, batch_size, hidden_layer, *extra]


mine_est = MI_Estimator(_mine_config()).MINE_MI
mine_est_local = MI_Estimator(_mine_config()).MINE_Local_MI
# The two global variants differ only in the trailing flag; the
# "_nocorrection" name suggests it toggles a correction term (confirm in all_estimators).
mine_est_global = MI_Estimator(_mine_config(True)).MINE_Global_MI
mine_est_global_nocorrection = MI_Estimator(_mine_config(False)).MINE_Global_MI
# -----------------------------------


# k-nearest-neighbour (KSG-family) estimators, all built with k1 neighbours.
k1 = 3
c_local = [1.0]
c_global = np.linspace(0.1, 2.0, 20)  # 20 evenly spaced candidate values
# Other c_global grids used previously: [0.8, 0.9, 1.0, 1.1, 1.2] or [1.0].
KSG_est = MI_Estimator([k1]).KSG
KSG_local_est = MI_Estimator([k1, c_local]).KSG_local
KSG_global_est = MI_Estimator([k1, c_global]).KSG_global
KSG_global_est_nomax = MI_Estimator([k1, [1.0]]).KSG_global
KSG_local_est_infnorm = MI_Estimator([k1, c_local]).KSG_local_infnorm
KSG_global_est_infnorm = MI_Estimator([k1, c_global]).KSG_global_infnorm

# -----------------------------------
# Mixed_est = MI_Estimator([k1]).Mixed_KSG

# Revised KSG; q is presumably the norm order (np.inf = max-norm), confirm in all_estimators.
k2, q = 3, np.inf
revised_KSG_est = MI_Estimator([k2, q]).KSG_revised

# -----------------------------------
# LNC estimator (disabled):
# k3 = 5
# alpha = 0.25
# LNC_est = MI_Estimator([k3,alpha]).LNC_MI

# Histogram/binning estimator; takes no parameters.
bin_est = MI_Estimator([]).bin_MI

# ----------------------------------
# infonce_est = MI_Estimator([100]).info_nce_MI
# infonce_est_local = MI_Estimator([100]).info_nce_MI_local
# infonce_est_global = MI_Estimator([100]).info_nce_MI_global


# hidden_ratio = [np.linspace(0.1,2.0,num=10)]
# hidden_ratio = np.arange(1,20)/50.0
# batch_size = 200
# MVIG_est = MI_Estimator([hidden_ratio,batch_size])
# VI_est = MI_Estimator([hidden_ratio[-1],batch_size])


# Estimators evaluated in this run: the MINE family only (the results file is
# named accordingly). Other selections used previously:
#   [KSG_est, KSG_local_est, KSG_global_est, KSG_local_est_infnorm, KSG_global_est_infnorm,
#    revised_KSG_est, mine_est, mine_est_local, mine_est_global, mine_est_global_nocorrection, bin_est]
#   [KSG_est, KSG_local_est_infnorm, KSG_global_est_infnorm]
#   [KSG_est, KSG_local_est, KSG_global_est, KSG_local_est_infnorm, KSG_global_est_infnorm,
#    revised_KSG_est, bin_est]
estimators = [
    mine_est,
    mine_est_local,
    mine_est_global,
    mine_est_global_nocorrection,
]

# Per-estimator accumulators, indexed in parallel with `estimators`:
# signed errors, raw estimates, and 1/0 "estimate below true MI" flags.
error_list = [[] for _ in estimators]
output_list = [[] for _ in estimators]
true_mi_list = []
one_sided_check_list = [[] for _ in estimators]

# print('here')

# Run configuration: sample size and dimensionality come from the command line.
N = int(args.N)
dim = int(args.dim)
trials = 40  # number of random (parameter, dataset) draws per run
transforms_x = args.tx
transforms_y = args.ty
# Transform names come from --tx/--ty (e.g. --tx randmat,cube,sigmoid); their
# numeric parameters live in params_dict.

# data_ID selects the generator: 0 -> MultivariateNormalDataset (the per-trial
# draw is the correlation rho), 1 -> GaussianAdditionDataset (the draw is the SNR).
datasets = [MultivariateNormalDataset, GaussianAdditionDataset]
data_ID = int(args.data_ID)
if not 0 <= data_ID < len(datasets):
    # A negative value would otherwise silently index `datasets` from the end.
    raise SystemExit(f"data_ID must be in 0..{len(datasets) - 1}, got {data_ID}")
# Upper bound of the uniform per-trial draw. For data_ID 0 it keeps |rho| < 1,
# so the joint covariance stays positive definite.
rho_max = 0.8 if data_ID == 0 else 2.0

# Not read anywhere in this file; kept in case external tooling expects them.
ksg_error = []
ksg_local_error = []
ksg_global_error = []

# Each trial draws a fresh dataset parameter and scores every estimator on it.
for _trial in range(trials):
    rho = np.random.rand() * rho_max
    dataset = datasets[data_ID](N, dim, rho, transforms_x, transforms_y)
    true_mi = dataset.true_mi
    per_estimator = zip(estimators, error_list, output_list, one_sided_check_list)
    for est, errs, outs, below_flags in per_estimator:
        E = est(dataset.x, dataset.y)
        gap = E - true_mi
        errs.append(gap)
        outs.append(E)
        # 1 when the estimate undershoots the true MI, else 0.
        below_flags.append(int(gap < 0))

    true_mi_list.append(true_mi)


# Baseline for normalisation: the error made by "predicting" each trial's MI
# with the true MI of a randomly permuted trial.
true_mi_list = np.array(true_mi_list)
shuffled_idx = np.random.permutation(len(true_mi_list))
permuted_mi_list = true_mi_list[shuffled_idx]
baseline_gap = permuted_mi_list - true_mi_list
ref_error1 = np.mean(np.abs(baseline_gap))          # baseline MAE
ref_error2 = np.sqrt(np.mean(baseline_gap**2))      # baseline RMSE

# Convert each estimator's error history to an array in place.
error_list[:] = [np.array(errs) for errs in error_list]


# Report per-estimator error statistics to the (redirected) stdout.
print('\n\n\n\n')
print(args)
print('\n')
for est, errors, outputs, below_flags in zip(estimators, error_list, output_list, one_sided_check_list):
    name = est.__name__
    errors = np.asarray(errors)
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors**2))
    bias = np.mean(errors)
    print(name+"(MAE):   "+str(mae))
    print(name+"(RMSE):   "+str(rmse))
    print(name+"(bias):   "+str(bias))
    # Normalised by the permutation baseline (MAE and bias by the baseline MAE).
    print(name+"(normalized MAE):   "+str(mae/ref_error1))
    # Bug fix: the old version printed sqrt(MSE / baseline_RMSE), which mixes a
    # squared error with an unsquared reference. RMSE / baseline RMSE is the
    # consistent ratio.
    print(name+"(normalized RMSE):   "+str(rmse/ref_error2))
    print(name+"(normalized bias):   "+str(bias/ref_error1))

    print(name+"(Spearman):"+str(stats.spearmanr(np.array(outputs),np.array(true_mi_list))))
    # Fraction of trials in which the estimate was below the true MI.
    print(name+" Direction:",np.mean((np.array(below_flags)==1)).astype(float))


# +
yy = time.time() - xx
print('\n\n Time Taken: ',yy)

# Restore the console stream and close the results file.
sys.stdout = orig_stdout
f.close()
# -


# print(error_list)
